/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 * Copyright (C) 2002 University of Waikato 
 */

package weka.classifiers.meta;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;

import junit.framework.Test;
import junit.framework.TestSuite;
import weka.classifiers.AbstractClassifierTest;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.evaluation.Prediction;
import weka.core.Attribute;
import weka.core.Instances;
import weka.core.NoSupportForMissingValuesException;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.UnsupportedAttributeTypeException;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveType;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 * Tests ThresholdSelector. Run from the command line with:
 * <p>
 * java weka.classifiers.meta.ThresholdSelectorTest
 * <p>
 * Every test sets the class index of the loaded data to 1, trains the
 * classifier under test on the first half of the data and collects
 * predictions for the second half (see {@link #useClassifier()}).
 * 
 * @author <a href="mailto:len@reeltwo.com">Len Trigg</a>
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision$
 */
public class ThresholdSelectorTest extends AbstractClassifierTest {

  /** Distribution handed to the dummy base classifier by default */
  private static final double[] DIST1 = new double[] { 0.25, 0.375, 0.5,
    0.625, 0.75, 0.875, 1.0 };

  /** Classpath location of the ARFF file the tests are run against */
  private static final String TEST_DATA_RESOURCE = "weka/classifiers/data/ClassifierTest.arff";

  /**
   * Number of attribute-removal repairs after which a further
   * UnsupportedAttributeTypeException is rethrown instead of retried
   */
  private static final int MAX_ATTRIBUTE_REPAIRS = 2;

  /**
   * Upper bound on how often the training data may be doubled in response to
   * a "Not enough instances" failure before the failure is rethrown
   */
  private static final int MAX_INFLATIONS = 10;

  /** A set of instances to test with */
  protected transient Instances m_Instances;

  /** Used to generate various types of predictions */
  protected transient EvaluationUtils m_Evaluation;

  /**
   * Creates the test case.
   * 
   * @param name the name of the test, passed on to the superclass
   */
  public ThresholdSelectorTest(String name) {
    super(name);
  }

  /**
   * Called by JUnit before each test method. This implementation runs the
   * superclass set-up, creates the evaluation helper and loads a test set of
   * Instances from the classpath.
   * 
   * @exception Exception if an error occurs reading the example instances.
   */
  @Override
  protected void setUp() throws Exception {
    super.setUp();

    m_Evaluation = new EvaluationUtils();

    InputStream in = ClassLoader.getSystemResourceAsStream(TEST_DATA_RESOURCE);
    // a missing resource yields null; report that clearly instead of letting
    // the reader constructor fail with a bare NullPointerException
    assertNotNull("Test data not found on the classpath: "
      + TEST_DATA_RESOURCE, in);

    BufferedReader reader = new BufferedReader(new InputStreamReader(in));
    try {
      m_Instances = new Instances(reader);
    } finally {
      // release the underlying stream whether or not parsing succeeded
      reader.close();
    }
  }

  /** Creates a default ThresholdSelector (based on {@link #DIST1}) */
  @Override
  public Classifier getClassifier() {
    return getClassifier(DIST1);
  }

  /**
   * Called by JUnit after each test method. Runs the superclass tear-down and
   * drops the references to the per-test fixtures.
   */
  @Override
  protected void tearDown() {
    super.tearDown();

    m_Evaluation = null;
    m_Instances = null;
  }

  /**
   * Creates a ThresholdSelector whose subclassifier is a
   * ThresholdSelectorDummyClassifier constructed from the given distribution.
   * 
   * @param dist the distribution handed to the dummy classifier
   * @return a new <code>ThresholdSelector</code>
   */
  public Classifier getClassifier(double[] dist) {
    return getClassifier(new ThresholdSelectorDummyClassifier(dist));
  }

  /**
   * Creates a ThresholdSelector with the given subclassifier.
   * 
   * @param classifier a <code>Classifier</code> to use as the subclassifier
   * @return a new <code>ThresholdSelector</code>
   */
  public Classifier getClassifier(Classifier classifier) {
    ThresholdSelector selector = new ThresholdSelector();
    selector.setClassifier(classifier);
    return selector;
  }

  /**
   * Builds a model using the current classifier using the first half of the
   * current data for training, and generates a bunch of predictions using the
   * remaining half of the data for testing.
   * <p>
   * If the evaluation fails for a reason this method knows how to work
   * around, the data is adjusted and the evaluation is retried:
   * <ul>
   * <li>unsupported attribute types: the offending attributes are removed
   * from both sets (rethrown once more than {@link #MAX_ATTRIBUTE_REPAIRS}
   * repairs have been made);</li>
   * <li>unsupported missing values: missing values are replaced (once);</li>
   * <li>"Not enough instances": the training set is doubled (at most
   * {@link #MAX_INFLATIONS} times, and never when it is empty).</li>
   * </ul>
   * Any failure that cannot be repaired is rethrown to the caller.
   * 
   * @return a <code>ArrayList</code> containing the predictions.
   * @throws Exception if the evaluation fails and cannot be repaired
   */
  protected ArrayList<Prediction> useClassifier() throws Exception {

    int total = m_Instances.numInstances();
    int mid = total / 2;
    Instances train = null;
    Instances test = null;
    try {
      train = new Instances(m_Instances, 0, mid);
      test = new Instances(m_Instances, mid, total - mid);
    } catch (Exception ex) {
      ex.printStackTrace();
      fail("Problem setting up to use classifier: " + ex);
    }
    Classifier classifier = m_Classifier;

    int attributeRepairs = 0;
    int inflations = 0;
    boolean missingValuesReplaced = false;

    while (true) {
      try {
        return m_Evaluation.getTrainTestPredictions(classifier, train, test);
      } catch (UnsupportedAttributeTypeException ex) {
        String msg = ex.getMessage();
        if (msg == null) {
          // nothing to inspect, so surface the original failure
          throw ex;
        }

        SelectedTag tag;
        boolean invert;
        if (msg.contains("string") && msg.contains("attributes")) {
          System.err.println("\nDeleting string attributes.");
          tag = new SelectedTag(Attribute.STRING, RemoveType.TAGS_ATTRIBUTETYPE);
          invert = false;
        } else if (msg.contains("only") && msg.contains("nominal")) {
          System.err.println("\nDeleting non-nominal attributes.");
          tag = new SelectedTag(Attribute.NOMINAL,
            RemoveType.TAGS_ATTRIBUTETYPE);
          invert = true;
        } else if (msg.contains("only") && msg.contains("numeric")) {
          System.err.println("\nDeleting non-numeric attributes.");
          tag = new SelectedTag(Attribute.NUMERIC,
            RemoveType.TAGS_ATTRIBUTETYPE);
          invert = true;
        } else {
          throw ex;
        }

        // configure the filter on the training data, then push the test data
        // through the very same filter instance
        RemoveType attFilter = new RemoveType();
        attFilter.setAttributeType(tag);
        attFilter.setInvertSelection(invert);
        attFilter.setInputFormat(train);
        train = Filter.useFilter(train, attFilter);
        attFilter.batchFinished();
        test = Filter.useFilter(test, attFilter);

        attributeRepairs++;
        if (attributeRepairs > MAX_ATTRIBUTE_REPAIRS) {
          throw ex;
        }
      } catch (NoSupportForMissingValuesException ex2) {
        if (missingValuesReplaced) {
          // replacement has already been applied to both sets; repeating it
          // cannot change the outcome and would only loop forever
          throw ex2;
        }
        System.err.println("\nReplacing missing values.");
        ReplaceMissingValues rmFilter = new ReplaceMissingValues();
        rmFilter.setInputFormat(train);
        train = Filter.useFilter(train, rmFilter);
        rmFilter.batchFinished();
        test = Filter.useFilter(test, rmFilter);
        missingValuesReplaced = true;
      } catch (IllegalArgumentException ex3) {
        String msg = ex3.getMessage();
        boolean notEnoughInstances = (msg != null)
          && msg.contains("Not enough instances");
        // doubling an empty set yields an empty set, and unbounded doubling
        // only ends when memory runs out, so give up in both situations
        if (!notEnoughInstances || (train.numInstances() == 0)
          || (inflations >= MAX_INFLATIONS)) {
          throw ex3;
        }
        System.err.println("\nInflating training data.");
        Instances trainNew = new Instances(train);
        for (int i = 0; i < train.numInstances(); i++) {
          trainNew.add(train.instance(i));
        }
        train = trainNew;
        inflations++;
      }
    }
  }

  /**
   * Checks that, with class 0 designated and range correction switched off,
   * the predicted probabilities for class index 0 stay within [0.25, 1.0],
   * the smallest and largest values of {@link #DIST1}.
   * 
   * @throws Exception if generating the predictions fails
   */
  public void testRangeNone() throws Exception {

    int cind = 0;
    ThresholdSelector selector = (ThresholdSelector) m_Classifier;
    selector.setDesignatedClass(new SelectedTag(ThresholdSelector.OPTIMIZE_0,
      ThresholdSelector.TAGS_OPTIMIZE));
    selector.setRangeCorrection(new SelectedTag(ThresholdSelector.RANGE_NONE,
      ThresholdSelector.TAGS_RANGE));
    m_Instances.setClassIndex(1);

    ArrayList<Prediction> result = useClassifier();
    assertTrue("No predictions were generated", !result.isEmpty());

    double minp = 0;
    double maxp = 0;
    for (int i = 0; i < result.size(); i++) {
      NominalPrediction p = (NominalPrediction) result.get(i);
      double prob = p.distribution()[cind];
      // the first prediction seeds both extremes
      if ((i == 0) || (prob < minp)) {
        minp = prob;
      }
      if ((i == 0) || (prob > maxp)) {
        maxp = prob;
      }
    }
    assertTrue("Upper limit shouldn't increase", maxp <= 1.0);
    assertTrue("Lower limit shouldn't decrease", minp >= 0.25);
  }

  /**
   * Checks that predictions are produced for every designated-class option
   * listed in ThresholdSelector.TAGS_OPTIMIZE.
   * 
   * @throws Exception if generating the predictions fails
   */
  public void testDesignatedClass() throws Exception {

    ThresholdSelector selector = (ThresholdSelector) m_Classifier;
    for (Tag element : ThresholdSelector.TAGS_OPTIMIZE) {
      selector.setDesignatedClass(new SelectedTag(element.getID(),
        ThresholdSelector.TAGS_OPTIMIZE));
      m_Instances.setClassIndex(1);
      ArrayList<Prediction> result = useClassifier();
      assertTrue("No predictions were generated for designated class tag "
        + element.getID(), !result.isEmpty());
    }
  }

  /**
   * Checks that predictions are produced for every evaluation mode listed in
   * ThresholdSelector.TAGS_EVAL.
   * 
   * @throws Exception if generating the predictions fails
   */
  public void testEvaluationMode() throws Exception {

    ThresholdSelector selector = (ThresholdSelector) m_Classifier;
    for (Tag element : ThresholdSelector.TAGS_EVAL) {
      selector.setEvaluationMode(new SelectedTag(element.getID(),
        ThresholdSelector.TAGS_EVAL));
      m_Instances.setClassIndex(1);
      ArrayList<Prediction> result = useClassifier();
      assertTrue("No predictions were generated for evaluation mode tag "
        + element.getID(), !result.isEmpty());
    }
  }

  /**
   * Checks that a fold count of 0 is rejected with an
   * IllegalArgumentException and that predictions are produced for the even
   * fold counts 2 to 18.
   * 
   * @throws Exception if generating the predictions fails
   */
  public void testNumXValFolds() throws Exception {

    ThresholdSelector selector = (ThresholdSelector) m_Classifier;
    try {
      selector.setNumXValFolds(0);
      fail("Expected IllegalArgumentException");
    } catch (IllegalArgumentException e) {
      // OK
    }

    for (int i = 2; i < 20; i += 2) {
      selector.setNumXValFolds(i);
      m_Instances.setClassIndex(1);
      ArrayList<Prediction> result = useClassifier();
      assertTrue("No predictions were generated for " + i + " folds",
        !result.isEmpty());
    }
  }

  /**
   * Builds the suite containing all tests of this class.
   * 
   * @return the test suite
   */
  public static Test suite() {
    return new TestSuite(ThresholdSelectorTest.class);
  }

  /**
   * Runs the suite with the JUnit text runner.
   * 
   * @param args ignored
   */
  public static void main(String[] args) {
    junit.textui.TestRunner.run(suite());
  }
}
